{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !wget https://storage.googleapis.com/xlnet/released_models/cased_L-12_H-768_A-12.zip\n",
    "# !unzip cased_L-12_H-768_A-12.zip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sentencepiece as spm\n",
    "from prepro_utils import preprocess_text, encode_ids\n",
    "\n",
    "sp_model = spm.SentencePieceProcessor()\n",
    "sp_model.Load('xlnet_cased_L-12_H-768_A-12/spiece.model')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from prepro_utils import preprocess_text, encode_ids\n",
    "\n",
    "def tokenize_fn(text, sp_model):\n",
    "    text = preprocess_text(text, lower = False)\n",
    "    return encode_ids(sp_model, text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "SEG_ID_A   = 0\n",
    "SEG_ID_B   = 1\n",
    "SEG_ID_CLS = 2\n",
    "SEG_ID_SEP = 3\n",
    "SEG_ID_PAD = 4\n",
    "\n",
    "special_symbols = {\n",
    "    \"<unk>\"  : 0,\n",
    "    \"<s>\"    : 1,\n",
    "    \"</s>\"   : 2,\n",
    "    \"<cls>\"  : 3,\n",
    "    \"<sep>\"  : 4,\n",
    "    \"<pad>\"  : 5,\n",
    "    \"<mask>\" : 6,\n",
    "    \"<eod>\"  : 7,\n",
    "    \"<eop>\"  : 8,\n",
    "}\n",
    "\n",
    "VOCAB_SIZE = 32000\n",
    "UNK_ID = special_symbols[\"<unk>\"]\n",
    "CLS_ID = special_symbols[\"<cls>\"]\n",
    "SEP_ID = special_symbols[\"<sep>\"]\n",
    "MASK_ID = special_symbols[\"<mask>\"]\n",
    "EOD_ID = special_symbols[\"<eod>\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "import re\n",
    "\n",
    "def get_assignment_map_from_checkpoint(tvars, init_checkpoint):\n",
    "    assignment_map = {}\n",
    "    initialized_variable_names = {}\n",
    "\n",
    "    name_to_variable = collections.OrderedDict()\n",
    "    for var in tvars:\n",
    "        name = var.name\n",
    "        m = re.match('^(.*):\\\\d+$', name)\n",
    "        if m is not None:\n",
    "            name = m.group(1)\n",
    "        name_to_variable[name] = var\n",
    "\n",
    "    init_vars = tf.train.list_variables(init_checkpoint)\n",
    "\n",
    "    assignment_map = collections.OrderedDict()\n",
    "    for x in init_vars:\n",
    "        (name, var) = (x[0], x[1])\n",
    "        if name not in name_to_variable:\n",
    "            continue\n",
    "        assignment_map[name] = name_to_variable[name]\n",
    "        initialized_variable_names[name] = 1\n",
    "        initialized_variable_names[name + ':0'] = 1\n",
    "\n",
    "    return (assignment_map, initialized_variable_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "import xlnet as xlnet_lib\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "\n",
    "class Model:\n",
    "    def __init__(self, xlnet_config, tokenizer, checkpoint, pool_mode = 'last'):\n",
    "\n",
    "        kwargs = dict(\n",
    "            is_training = True,\n",
    "            use_tpu = False,\n",
    "            use_bfloat16 = False,\n",
    "            dropout = 0.0,\n",
    "            dropatt = 0.0,\n",
    "            init = 'normal',\n",
    "            init_range = 0.1,\n",
    "            init_std = 0.05,\n",
    "            clamp_len = -1,\n",
    "        )\n",
    "\n",
    "        xlnet_parameters = xlnet_lib.RunConfig(**kwargs)\n",
    "\n",
    "        self._tokenizer = tokenizer\n",
    "        _graph = tf.Graph()\n",
    "        with _graph.as_default():\n",
    "            self.X = tf.placeholder(tf.int32, [None, None])\n",
    "            self.segment_ids = tf.placeholder(tf.int32, [None, None])\n",
    "            self.input_masks = tf.placeholder(tf.float32, [None, None])\n",
    "\n",
    "            xlnet_model = xlnet_lib.XLNetModel(\n",
    "                xlnet_config = xlnet_config,\n",
    "                run_config = xlnet_parameters,\n",
    "                input_ids = tf.transpose(self.X, [1, 0]),\n",
    "                seg_ids = tf.transpose(self.segment_ids, [1, 0]),\n",
    "                input_mask = tf.transpose(self.input_masks, [1, 0]),\n",
    "            )\n",
    "\n",
    "            self.logits = xlnet_model.get_pooled_out(pool_mode, True)\n",
    "            self._sess = tf.InteractiveSession()\n",
    "            self._sess.run(tf.global_variables_initializer())\n",
    "            tvars = tf.trainable_variables()\n",
    "            assignment_map, _ = get_assignment_map_from_checkpoint(\n",
    "                tvars, checkpoint\n",
    "            )\n",
    "            self._saver = tf.train.Saver(var_list = assignment_map)\n",
    "            attentions = [\n",
    "                n.name\n",
    "                for n in tf.get_default_graph().as_graph_def().node\n",
    "                if 'rel_attn/Softmax' in n.name\n",
    "            ]\n",
    "            g = tf.get_default_graph()\n",
    "            self.attention_nodes = [\n",
    "                g.get_tensor_by_name('%s:0' % (a)) for a in attentions\n",
    "            ]\n",
    "\n",
    "    def vectorize(self, strings):\n",
    "        \"\"\"\n",
    "        Vectorize string inputs using bert attention.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        strings : str / list of str\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        array: vectorized strings\n",
    "        \"\"\"\n",
    "\n",
    "        if isinstance(strings, list):\n",
    "            if not isinstance(strings[0], str):\n",
    "                raise ValueError('input must be a list of strings or a string')\n",
    "        else:\n",
    "            if not isinstance(strings, str):\n",
    "                raise ValueError('input must be a list of strings or a string')\n",
    "        if isinstance(strings, str):\n",
    "            strings = [strings]\n",
    "\n",
    "        input_ids, input_masks, segment_ids, _ = xlnet_tokenization(\n",
    "            self._tokenizer, strings\n",
    "        )\n",
    "        return self._sess.run(\n",
    "            self.logits,\n",
    "            feed_dict = {\n",
    "                self.X: input_ids,\n",
    "                self.segment_ids: segment_ids,\n",
    "                self.input_masks: input_masks,\n",
    "            },\n",
    "        )\n",
    "\n",
    "    def attention(self, strings, method = 'last', **kwargs):\n",
    "        \"\"\"\n",
    "        Get attention string inputs from xlnet attention.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        strings : str / list of str\n",
    "        method : str, optional (default='last')\n",
    "            Attention layer supported. Allowed values:\n",
    "\n",
    "            * ``'last'`` - attention from last layer.\n",
    "            * ``'first'`` - attention from first layer.\n",
    "            * ``'mean'`` - average attentions from all layers.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        array: attention\n",
    "        \"\"\"\n",
    "\n",
    "        if isinstance(strings, list):\n",
    "            if not isinstance(strings[0], str):\n",
    "                raise ValueError('input must be a list of strings or a string')\n",
    "        else:\n",
    "            if not isinstance(strings, str):\n",
    "                raise ValueError('input must be a list of strings or a string')\n",
    "        if isinstance(strings, str):\n",
    "            strings = [strings]\n",
    "\n",
    "        method = method.lower()\n",
    "        if method not in ['last', 'first', 'mean']:\n",
    "            raise Exception(\n",
    "                \"method not supported, only support ['last', 'first', 'mean']\"\n",
    "            )\n",
    "\n",
    "        input_ids, input_masks, segment_ids, s_tokens = xlnet_tokenization(\n",
    "            self._tokenizer, strings\n",
    "        )\n",
    "        maxlen = max([len(s) for s in s_tokens])\n",
    "        s_tokens = padding_sequence(s_tokens, maxlen, pad_int = '<cls>')\n",
    "        attentions = self._sess.run(\n",
    "            self.attention_nodes,\n",
    "            feed_dict = {\n",
    "                self.X: input_ids,\n",
    "                self.segment_ids: segment_ids,\n",
    "                self.input_masks: input_masks,\n",
    "            },\n",
    "        )\n",
    "\n",
    "        if method == 'first':\n",
    "            cls_attn = np.transpose(attentions[0][:, 0], (1, 0, 2))\n",
    "\n",
    "        if method == 'last':\n",
    "            cls_attn = np.transpose(attentions[-1][:, 0], (1, 0, 2))\n",
    "\n",
    "        if method == 'mean':\n",
    "            cls_attn = np.transpose(\n",
    "                np.mean(attentions, axis = 0).mean(axis = 1), (1, 0, 2)\n",
    "            )\n",
    "\n",
    "        cls_attn = np.mean(cls_attn, axis = 1)\n",
    "        total_weights = np.sum(cls_attn, axis = -1, keepdims = True)\n",
    "        attn = cls_attn / total_weights\n",
    "        output = []\n",
    "        for i in range(attn.shape[0]):\n",
    "            output.append(\n",
    "                merge_sentencepiece_tokens(list(zip(s_tokens[i], attn[i])))\n",
    "            )\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def merge_sentencepiece_tokens(paired_tokens, weighted = True):\n",
    "    new_paired_tokens = []\n",
    "    n_tokens = len(paired_tokens)\n",
    "    rejected = ['<cls>', '<sep>']\n",
    "\n",
    "    i = 0\n",
    "\n",
    "    while i < n_tokens:\n",
    "\n",
    "        current_token, current_weight = paired_tokens[i]\n",
    "        if not current_token.startswith('▁') and current_token not in rejected:\n",
    "            previous_token, previous_weight = new_paired_tokens.pop()\n",
    "            merged_token = previous_token\n",
    "            merged_weight = [previous_weight]\n",
    "            while (\n",
    "                not current_token.startswith('▁')\n",
    "                and current_token not in rejected\n",
    "            ):\n",
    "                merged_token = merged_token + current_token.replace('▁', '')\n",
    "                merged_weight.append(current_weight)\n",
    "                i = i + 1\n",
    "                current_token, current_weight = paired_tokens[i]\n",
    "            merged_weight = np.mean(merged_weight)\n",
    "            new_paired_tokens.append((merged_token, merged_weight))\n",
    "\n",
    "        else:\n",
    "            new_paired_tokens.append((current_token, current_weight))\n",
    "            i = i + 1\n",
    "\n",
    "    words = [\n",
    "        i[0].replace('▁', '')\n",
    "        for i in new_paired_tokens\n",
    "        if i[0] not in ['<cls>', '<sep>', '<pad>']\n",
    "    ]\n",
    "    weights = [\n",
    "        i[1]\n",
    "        for i in new_paired_tokens\n",
    "        if i[0] not in ['<cls>', '<sep>', '<pad>']\n",
    "    ]\n",
    "    if weighted:\n",
    "        weights = np.array(weights)\n",
    "        weights = weights / np.sum(weights)\n",
    "    return list(zip(words, weights))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def padding_sequence(seq, maxlen, padding = 'post', pad_int = 0):\n",
    "    padded_seqs = []\n",
    "    for s in seq:\n",
    "        if padding == 'post':\n",
    "            padded_seqs.append(s + [pad_int] * (maxlen - len(s)))\n",
    "        if padding == 'pre':\n",
    "            padded_seqs.append([pad_int] * (maxlen - len(s)) + s)\n",
    "    return padded_seqs\n",
    "\n",
    "\n",
    "def xlnet_tokenization(tokenizer, texts):\n",
    "    input_ids, input_masks, segment_ids, s_tokens = [], [], [], []\n",
    "    for text in texts:\n",
    "        tokens_a = tokenize_fn(text, tokenizer)\n",
    "        tokens = []\n",
    "        segment_id = []\n",
    "        for token in tokens_a:\n",
    "            tokens.append(token)\n",
    "            segment_id.append(SEG_ID_A)\n",
    "\n",
    "        tokens.append(SEP_ID)\n",
    "        segment_id.append(SEG_ID_A)\n",
    "        tokens.append(CLS_ID)\n",
    "        segment_id.append(SEG_ID_CLS)\n",
    "\n",
    "        input_id = tokens\n",
    "        input_mask = [0] * len(input_id)\n",
    "\n",
    "        input_ids.append(input_id)\n",
    "        input_masks.append(input_mask)\n",
    "        segment_ids.append(segment_id)\n",
    "        s_tokens.append([tokenizer.IdToPiece(i) for i in tokens])\n",
    "\n",
    "    maxlen = max([len(i) for i in input_ids])\n",
    "    input_ids = padding_sequence(input_ids, maxlen, padding = 'pre')\n",
    "    input_masks = padding_sequence(\n",
    "        input_masks, maxlen, padding = 'pre', pad_int = 1\n",
    "    )\n",
    "    segment_ids = padding_sequence(\n",
    "        segment_ids, maxlen, padding = 'pre', pad_int = SEG_ID_PAD\n",
    "    )\n",
    "\n",
    "    return input_ids, input_masks, segment_ids, s_tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py:1735: UserWarning: An interactive session is already active. This can cause out-of-memory errors in some cases. You must explicitly call `InteractiveSession.close()` to release resources held by the other session(s).\n",
      "  warnings.warn('An interactive session is already active. This can '\n",
      "W0806 00:38:22.101047 140372522661696 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/saver.py:1276: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use standard file APIs to check for files with this prefix.\n"
     ]
    }
   ],
   "source": [
    "xlnet_config = xlnet_lib.XLNetConfig(json_path = 'xlnet_cased_L-12_H-768_A-12/xlnet_config.json')\n",
    "xlnet_checkpoint = 'xlnet_cased_L-12_H-768_A-12/xlnet_model.ckpt'\n",
    "model = Model(\n",
    "xlnet_config, sp_model, xlnet_checkpoint, pool_mode = 'last'\n",
    ")\n",
    "model._saver.restore(model._sess, xlnet_checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1, 768)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.vectorize('i love u').shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[('i', 0.2029094),\n",
       "  ('love', 0.17279622),\n",
       "  ('u', 0.21883217),\n",
       "  ('sooo', 0.19822793),\n",
       "  ('much!', 0.20723426)]]"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.attention('i love u sooo much!', method = 'first')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[('i', 0.5867717),\n",
       "  ('love', 0.08966514),\n",
       "  ('u', 0.23113684),\n",
       "  ('sooo', 0.057987895),\n",
       "  ('much!', 0.03443839)]]"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.attention('i love u sooo much!', method = 'last')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[('i', 0.22920714),\n",
       "  ('love', 0.2319069),\n",
       "  ('u', 0.18129185),\n",
       "  ('sooo', 0.13547097),\n",
       "  ('much!', 0.22212313)]]"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.attention('i love u sooo much!', method = 'mean')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
